{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# fast.ai lesson 1 - training on Notebook Instance and export to torch.jit model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Overview\n",
    "This notebook shows how to use the SageMaker Python SDK to train your fast.ai model on a SageMaker notebook instance then export it as a torch.jit model to be used for inference on AWS Lambda.\n",
    "\n",
    "## Set up the environment\n",
    "\n",
    "You will need a Jupyter notebook with the `boto3` and `fastai` libraries installed. You can do this with the command `pip install boto3 fastai`\n",
    "\n",
    "This notebook was created and tested on a single ml.p3.2xlarge notebook instance. \n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train your model\n",
    "\n",
    "We are going to train a fast.ai model as per [Lesson 1 of the fast.ai MOOC course](https://course.fast.ai/videos/?lesson=1) locally on the SageMaker Notebook instance. We will then save the model weights and upload them to S3.\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%reload_ext autoreload\n",
    "%autoreload 2\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import io\n",
    "import tarfile\n",
    "\n",
    "import PIL\n",
    "\n",
    "import boto3\n",
    "\n",
    "from fastai.vision import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "path = untar_data(URLs.PETS); path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "path_anno = path/'annotations'\n",
    "path_img = path/'images'\n",
    "fnames = get_image_files(path_img)\n",
    "np.random.seed(2)\n",
    "pat = re.compile(r'/([^/]+)_\\d+.jpg$')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bs=64\n",
    "img_size=299"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data = ImageDataBunch.from_name_re(path_img, fnames, pat, ds_tfms=get_transforms(),\n",
    "                                   size=img_size, bs=bs//2).normalize(imagenet_stats)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn = cnn_learner(data, models.resnet50, metrics=error_rate)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn.lr_find()\n",
    "learn.recorder.plot()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn.fit_one_cycle(8)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn.unfreeze()\n",
    "learn.fit_one_cycle(3, max_lr=slice(1e-6,1e-4))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Export model and upload to S3"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now that we have trained our model we need to export it, create a tarball of the artefacts and upload to S3.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "First we need to export the class names from the data object into a text file."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "save_texts(path_img/'models/classes.txt', data.classes)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we need to export the model in the [PyTorch TorchScript format](https://pytorch.org/docs/stable/jit.html) so we can load into an AWS Lambda function."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trace_input = torch.ones(1,3,img_size,img_size).cuda()\n",
    "jit_model = torch.jit.trace(learn.model.float(), trace_input)\n",
    "model_file='resnet50_jit.pth'\n",
    "output_path = str(path_img/f'models/{model_file}')\n",
    "torch.jit.save(jit_model, output_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next step is to create a tarfile of the exported classes file and model weights."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tar_file=path_img/'models/model.tar.gz'\n",
    "classes_file='classes.txt'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with tarfile.open(tar_file, 'w:gz') as f:\n",
    "    f.add(path_img/f'models/{model_file}', arcname=model_file)\n",
    "    f.add(path_img/f'models/{classes_file}', arcname=classes_file)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we need to upload the model tarball to S3."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "s3 = boto3.resource('s3')\n",
    "s3.meta.client.upload_file(str(tar_file), 'REPLACE_WITH_YOUR_BUCKET_NAME', 'fastai-models/lesson1/model.tar.gz')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
